Example-21: Periodic fixed points (4D accelerator mapping)
[1]:
# This example illustrates how to compute periodic fixed points of a 4D symplectic mapping
[2]:
# Import
import numpy
import jax
from jax import jit
from jax import vmap
jax.numpy.set_printoptions(linewidth=256)
# Test symplectic mapping and corresponding inverse
from tohubohu.util import forward4D
from tohubohu.util import inverse4D
# REM factory
from tohubohu import rem
# Fixed point
from tohubohu import iterate
from tohubohu import prime
from tohubohu import unique
from tohubohu import chain
from tohubohu import monodromy
from tohubohu import combine
from tohubohu import classify
# Iteration
from tohubohu import nest
from tohubohu import nest_list
# Plotting
import plotly.graph_objects as go
[3]:
# Set data type: enable double precision (float64) in JAX.
# This has to run before any JAX arrays are created so that they default to float64;
# the 1.0E-9 tolerances used in the fixed point search presumably need float64 — confirm.
jax.config.update("jax_enable_x64", True)
[4]:
# Make the first CPU device the default placement for new arrays and computations.
# NOTE(review): presumably CPU is chosen because jax.numpy.linalg.eig (used below)
# is CPU-only in many JAX versions — confirm.
cpu_devices = jax.devices('cpu')
device = cpu_devices[0]
jax.config.update('jax_default_device', device)
[5]:
# Mapping parameters.
# Tunes (nux, nuy) give phase advances mu = 2*pi*nu; the map takes their cos/sin.
nux, nuy = 0.168, 0.201
mux = 2*jax.numpy.pi*nux
muy = 2*jax.numpy.pi*nuy
cx, sx = jax.numpy.cos(mux), jax.numpy.sin(mux)
cy, sy = jax.numpy.cos(muy), jax.numpy.sin(muy)
# Extra map parameter; its meaning is defined by forward4D/inverse4D — TODO confirm
mu = 0.0
# Parameter vector layout: [cos mux, sin mux, cos muy, sin muy, mu]
k = jax.numpy.asarray([cx, sx, cy, sy, mu])
# Sanity check: both the map and its inverse send the origin to itself
x = jax.numpy.zeros(4)
print(forward4D(x, k))
print(inverse4D(x, k))
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[6]:
# Set initial points for fixed point computation:
# `size` random guesses, each coordinate uniform in [-5, 5).
# Every coordinate is drawn from its own subkey so the four columns are independent.
# Reusing one key for two coordinates would make those columns identical and
# restrict the search to a lower-dimensional subspace.
size = 1024
seed = jax.random.PRNGKey(0)
key_qx, key_px, key_qy, key_py = jax.random.split(seed, 4)
qxs = jax.random.uniform(key_qx, shape=(size, ), minval=-5, maxval=5)
pxs = jax.random.uniform(key_px, shape=(size, ), minval=-5, maxval=5)
qys = jax.random.uniform(key_qy, shape=(size, ), minval=-5, maxval=5)
pys = jax.random.uniform(key_py, shape=(size, ), minval=-5, maxval=5)
# State layout per row: (qx, qy, px, py)
xs = jax.numpy.stack([qxs, qys, pxs, pys]).T
print(xs.shape)
(1024, 4)
[7]:
# Search for period-`order` fixed points starting from every initial guess in xs
order = 5
# Passed to `iterate` as its first argument (presumably the maximum number of solver iterations)
limit = 32
solver = jit(vmap(iterate(limit, forward4D, order=order), (0, None)))
points = solver(xs, k)
# Keep only fully finite results.
# Some starts yield NaN (presumably non-converged runs); dropping inf as well
# guards against diverged runs.
points = points[jax.numpy.all(jax.numpy.isfinite(points), -1)]
# Select prime points
mask = jit(vmap(prime(forward4D, order=order, rtol=1.0E-9, atol=1.0E-9), (0, None)))(points, k)
points = points[mask]
# Select unique chains
mask = unique(order, forward4D, points, k, tol=1.0E-6)
points = points[mask]
# Generate chains
chains = jax.numpy.vstack(jit(vmap(chain(order, forward4D), (0, None)))(points, k))
# Result
print(points)
[[-0.00882392 0.07952321 0.00514389 0.1087343 ]
[ 0.10033137 -0.6373762 0.54186547 0.07477984]
[-0.00882392 -0.07952321 0.00514389 -0.1087343 ]
[ 0.10033137 0.6373762 0.54186547 -0.07477984]
[-0.13885106 -0.5302854 0.61644478 -0.16862399]]
[8]:
# Test fixed points: iterate each point `order` times and print its orbit.
# The final iterate must reproduce the starting point, otherwise the cell fails.
orbit_of = nest_list(order, forward4D)
for point in points:
    print(point.reshape(1, -1))
    orbit = orbit_of(point, k)
    print(orbit)
    print()
    # Rows of `orbit` are iterates 1..order, so the last row should close the orbit
    numpy.testing.assert_allclose(orbit[-1], point, rtol=1.0E-6, atol=1.0E-6)
[[-0.00882392 0.07952321 0.00514389 0.1087343 ]]
[[-5.30688975e-03 1.29057305e-01 7.13535543e-03 -4.24083616e-02]
[-1.08748682e-02 1.07182682e-10 -5.91313832e-05 -1.35425081e-01]
[-5.30688976e-03 -1.29057304e-01 9.49226933e-03 -4.10385760e-02]
[-8.82392005e-03 -7.95232077e-02 1.10218455e-03 1.10137716e-01]
[-8.82392006e-03 7.95232076e-02 5.14389444e-03 1.08734303e-01]]
[[ 0.10033137 -0.6373762 0.54186547 0.07477984]]
[[ 1.76207373e-01 7.88508390e-17 -1.55245191e-02 6.68824780e-01]
[ 1.00331375e-01 6.37376198e-01 -1.45683438e-01 2.02677498e-01]
[-4.22086519e-01 2.64411108e-01 -3.54298666e-01 -5.84745422e-01]
[-4.22086519e-01 -2.64411108e-01 2.46054870e-01 -3.61536693e-01]
[ 1.00331375e-01 -6.37376198e-01 5.41865472e-01 7.47798373e-02]]
[[-0.00882392 -0.07952321 0.00514389 -0.1087343 ]]
[[-5.30688975e-03 -1.29057305e-01 7.13535544e-03 4.24083614e-02]
[-1.08748682e-02 -2.65970035e-10 -5.91313883e-05 1.35425081e-01]
[-5.30688976e-03 1.29057304e-01 9.49226933e-03 4.10385762e-02]
[-8.82392005e-03 7.95232079e-02 1.10218456e-03 -1.10137716e-01]
[-8.82392007e-03 -7.95232074e-02 5.14389443e-03 -1.08734303e-01]]
[[ 0.10033137 0.6373762 0.54186547 -0.07477984]]
[[ 1.76207373e-01 1.23154931e-15 -1.55245191e-02 -6.68824780e-01]
[ 1.00331375e-01 -6.37376198e-01 -1.45683438e-01 -2.02677498e-01]
[-4.22086519e-01 -2.64411108e-01 -3.54298666e-01 5.84745422e-01]
[-4.22086519e-01 2.64411108e-01 2.46054870e-01 3.61536693e-01]
[ 1.00331375e-01 6.37376198e-01 5.41865472e-01 -7.47798373e-02]]
[[-0.13885106 -0.5302854 0.61644478 -0.16862399]]
[[ 2.40083387e-01 -4.61727411e-01 2.95508515e-01 4.09626623e-01]
[ 2.40083387e-01 4.61727411e-01 -1.39956345e-01 6.31332785e-01]
[-1.38851060e-01 5.30285400e-01 -3.54521788e-01 -3.15885370e-01]
[-6.04835944e-01 -3.28395082e-16 -1.82913260e-01 -5.56450048e-01]
[-1.38851060e-01 -5.30285400e-01 6.16444777e-01 -1.68623990e-01]]
[9]:
# The type of each periodic point is inferred from the eigenvalues of its monodromy matrix:
#   1. compute the monodromy matrix of every point,
#   2. compute its eigenvalues/eigenvectors and group them into pairs with `combine`,
#   3. `classify` returns one flag per eigenvalue pair.
# Flag patterns and the corresponding type:
#   [True]        -- E
#   [False]       -- H
#   [True, True]  -- EE
#   [True, False] -- EH
# In the printed results, True marks a pair on the unit circle (elliptic, E) and
# False marks a real pair (hyperbolic, H).
matrices = jit(vmap(monodromy(order, forward4D), (0, None)))(points, k)
for point, matrix in zip(points, matrices):
    eigenvalues, eigenvectors = combine(*jax.numpy.linalg.eig(matrix))
    print(point)
    print(eigenvalues.flatten())
    print(classify(eigenvalues))
    print()
[-0.00882392 0.07952321 0.00514389 0.1087343 ]
[0.5602869 +8.28298609e-01j 0.5602869 -8.28298609e-01j 0.99999999+1.28074277e-04j 0.99999999-1.28074277e-04j]
[ True True]
[ 0.10033137 -0.6373762 0.54186547 0.07477984]
[ 1.50322526+0.j 0.66523629+0.j -0.44205249+0.j -2.26217478+0.j]
[False False]
[-0.00882392 -0.07952321 0.00514389 -0.1087343 ]
[0.5602869 +8.28298609e-01j 0.5602869 -8.28298609e-01j 0.99999999+1.28074277e-04j 0.99999999-1.28074277e-04j]
[ True True]
[ 0.10033137 0.6373762 0.54186547 -0.07477984]
[ 1.50322526+0.j 0.66523629+0.j -0.44205249+0.j -2.26217478+0.j]
[False False]
[-0.13885106 -0.5302854 0.61644478 -0.16862399]
[ 0.89755228+0.44090804j 0.89755228-0.44090804j -0.31026879+0.j -3.22301194+0.j ]
[ True False]
[10]:
# Generate a long trajectory from a point slightly off the first periodic point.
# Every coordinate is shifted by the same offset.
delta = 1.0E-3
point = points[0] + delta
# Transpose so each phase-space coordinate becomes its own array (layout qx, qy, px, py)
qx, qy, px, py = nest_list(2**14, forward4D)(point, k).T
[11]:
# Plot the trajectory as a 3D projection onto (qy, py, qx), with markers coloured by py
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=qy,
            y=py,
            z=qx,
            mode='markers',
            marker=dict(size=1, color=py, colorscale='Viridis', colorbar=dict(title='py'))
        )
    ]
)
fig.update_layout(
    title=f'Trajectory near a period-{order} point (4D mapping, projection onto qy, py, qx)',
    width=1000, height=800,
    scene=dict(xaxis_title='qy', yaxis_title='py', zaxis_title='qx', aspectmode='cube')
)
fig.show()